{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "computation_graph_with_solution.ipynb",
      "version": "0.3.2",
      "views": {},
      "default_view": {},
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "WJ-h8gN6BCGW",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Computation Graphs and Back Propagation"
      ]
    },
    {
      "metadata": {
        "id": "vSVob2t1BCGb",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "from __future__ import division\n",
        "import numpy as np"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Ik9ypdkpBCGw",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Define a Set of Computation Nodes\n",
        "\n",
        "| &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Name  &nbsp;&nbsp;&nbsp;&nbsp;  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;    |   &nbsp;&nbsp; &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Formula &nbsp;&nbsp;&nbsp;&nbsp;   &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;   &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;    |   &nbsp;&nbsp;  Gradients  |\n",
        "|:-------------:|:------------- |:----- |\n",
        "| Linear      | $y=x\\cdot W+b$ | $\\frac{\\partial \\mathcal{L}}{\\partial x}=\\frac{\\partial \\mathcal{L}}{\\partial y}\\cdot W^T\\\\\\frac{\\partial \\mathcal{L}}{\\partial W}=x^T\\cdot\\frac{\\partial \\mathcal{L}}{\\partial y}\\\\\\frac{\\partial \\mathcal{L}}{\\partial b}=\\frac{\\partial \\mathcal{L}}{\\partial y}$ |\n",
        "| Sigmoid     | $y=\\frac{1}{1+e^{-x}}$  | $\\frac{\\partial \\mathcal{L}}{\\partial x}=\\frac{\\partial \\mathcal{L}}{\\partial y}(1-y)y$ |\n",
        "| Softmax     | $y_j=\\frac{e^{x_j}}{\\sum\\limits_i e^{x_i}}$ | $\\frac{\\partial \\mathcal{L}}{\\partial x_j}=\\frac{\\partial \\mathcal{L}}{\\partial y_j}y_j-y_j\\sum\\limits_i \\frac{\\partial \\mathcal{L}}{\\partial y_i}y_i$ |\n",
        "| CrossEntropy | $y=-\\sum\\limits_i p_i \\log(x_i)$ | $\\frac{\\partial \\mathcal{L}}{\\partial x_i}=-\\frac{\\partial \\mathcal{L}}{\\partial y}\\frac{p_i}{x_i}$ |\n",
        "| Mean  | $y=\\frac{1}{N}\\sum\\limits_i x_i$ | $\\frac{\\partial \\mathcal{L}}{\\partial x_i}=\\frac{1}{N}\\frac{\\partial \\mathcal{L}}{\\partial y}$ |"
      ]
    },
    {
      "metadata": {
        "id": "o6bEV2CvBCG0",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "class Linear(object):\n",
        "    '''\n",
        "    Affine (fully connected) node, y = x W + b.\n",
        "\n",
        "    Attributes:\n",
        "        parameters (list): trainable arrays [W, b]; W has shape\n",
        "            (input_shape, output_shape) and b has shape (output_shape,).\n",
        "        parameters_deltas (list): gradients [dL/dW, dL/db]; both entries are\n",
        "            None until backward has been called.\n",
        "        x (ndarray): input cached by the most recent forward call.\n",
        "    '''\n",
        "    def __init__(self, input_shape, output_shape, mean=0, variance=0.01):\n",
        "        # NOTE: despite its name, `variance` multiplies standard-normal samples,\n",
        "        # so it acts as the standard deviation of the initial values.\n",
        "        # W is drawn before b, which fixes the order of the random stream.\n",
        "        weight = mean + variance * np.random.randn(input_shape, output_shape)\n",
        "        bias = mean + variance * np.random.randn(output_shape)\n",
        "        self.parameters = [weight, bias]\n",
        "        self.parameters_deltas = [None, None]\n",
        "\n",
        "    def forward(self, x, *args):\n",
        "        '''Cache the input and return x W + b; extra args are ignored.'''\n",
        "        weight, bias = self.parameters\n",
        "        self.x = x\n",
        "        return np.matmul(x, weight) + bias\n",
        "\n",
        "    def backward(self, delta):\n",
        "        '''\n",
        "        Store the parameter gradients and propagate the gradient to the input.\n",
        "\n",
        "        Args:\n",
        "            delta (ndarray): gradient of L with respect to the node's output, dL/dy.\n",
        "\n",
        "        Returns:\n",
        "            ndarray: gradient of L with respect to the node's input, dL/dx.\n",
        "        '''\n",
        "        weight = self.parameters[0]\n",
        "        self.parameters_deltas[0] = np.dot(self.x.T, delta)\n",
        "        # the bias is shared by every row of the batch, so its gradient sums over axis 0\n",
        "        self.parameters_deltas[1] = delta.sum(axis=0)\n",
        "        return np.dot(delta, weight.T)\n",
        "\n",
        "\n",
        "class F(object):\n",
        "    '''\n",
        "    Base class for nodes without trainable parameters.\n",
        "\n",
        "    It exposes empty `parameters` and `parameters_deltas` lists, mirroring\n",
        "    the attributes of parameterized nodes.\n",
        "    '''\n",
        "    def __init__(self):\n",
        "        self.parameters, self.parameters_deltas = [], []\n",
        "\n",
        "\n",
        "class Sigmoid(F):\n",
        "    '''Element-wise logistic function, y = 1 / (1 + exp(-x)).'''\n",
        "    def forward(self, x, *args):\n",
        "        '''Cache input and output, then return the output; extra args are ignored.'''\n",
        "        self.x = x\n",
        "        denominator = 1.0 + np.exp(-x)\n",
        "        self.y = 1.0 / denominator\n",
        "        return self.y\n",
        "\n",
        "    def backward(self, delta):\n",
        "        '''Return dL/dx = dL/dy * (1 - y) * y, using the output cached by forward.'''\n",
        "        local_grad = (1 - self.y) * self.y\n",
        "        return delta * local_grad\n",
        "\n",
        "\n",
        "class Softmax(F):\n",
        "    '''Softmax taken over the last axis of the input.'''\n",
        "    def forward(self, x, *args):\n",
        "        '''Cache input and output, then return the output; extra args are ignored.'''\n",
        "        self.x = x\n",
        "        # subtracting the per-row maximum keeps exp from overflowing\n",
        "        shifted = x - x.max(axis=-1, keepdims=True)\n",
        "        weights = np.exp(shifted)\n",
        "        self.y = weights / weights.sum(axis=-1, keepdims=True)\n",
        "        return self.y\n",
        "\n",
        "    def backward(self, delta):\n",
        "        '''Return dL/dx_j = dL/dy_j * y_j - y_j * sum_i(dL/dy_i * y_i).'''\n",
        "        weighted = delta * self.y\n",
        "        return weighted - self.y * weighted.sum(axis=-1, keepdims=True)\n",
        "\n",
        "\n",
        "class CrossEntropy(F):\n",
        "    '''Cross entropy y = -sum_i p_i * log(x_i), summed over the last axis.'''\n",
        "    def forward(self, x, p, *args):\n",
        "        '''\n",
        "        Args:\n",
        "            x (ndarray): predicted probabilities.\n",
        "            p (ndarray): target probabilities; for the MNIST data set this is\n",
        "                a one-hot label.\n",
        "\n",
        "        Returns:\n",
        "            ndarray: cross entropy with the last axis summed out.\n",
        "        '''\n",
        "        self.p = p\n",
        "        self.x = x\n",
        "        # x is clipped from below at 1e-15 so that log never receives zero\n",
        "        log_x = np.log(np.maximum(x, 1e-15))\n",
        "        return (-p * log_x).sum(-1)\n",
        "\n",
        "    def backward(self, delta):\n",
        "        '''Return dL/dx_i = -dL/dy * p_i / x_i, with x clipped as in forward.'''\n",
        "        clipped = np.maximum(self.x, 1e-15)\n",
        "        # delta has one entry per row; add a trailing axis so it broadcasts over classes\n",
        "        upstream = -delta[..., None] * 1.\n",
        "        return upstream / clipped * self.p\n",
        "\n",
        "\n",
        "class Mean(F):\n",
        "    '''Average over all elements of the input.'''\n",
        "    def forward(self, x, *args):\n",
        "        '''Cache the input and return its mean over every axis; extra args are ignored.'''\n",
        "        self.x = x\n",
        "        return x.mean()\n",
        "\n",
        "    def backward(self, delta):\n",
        "        '''Spread the gradient evenly: every input element receives delta / N.'''\n",
        "        count = np.prod(self.x.shape)\n",
        "        ones = np.ones(self.x.shape)\n",
        "        return delta * ones / count"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "y9CtKqpmBCHB",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Sanity Check"
      ]
    },
    {
      "metadata": {
        "id": "zoa9gxaqBCHF",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def numdiff(node, x, var, y_delta, delta, args):\n",
        "    '''\n",
        "    Estimate dL/dvar numerically with central finite differences.\n",
        "\n",
        "    Every element of `var` is shifted by +delta/2 and -delta/2 in place,\n",
        "    `node.forward(x, *args)` is evaluated for both shifts, and the change of\n",
        "    the output is contracted with `y_delta`.\n",
        "\n",
        "    Args:\n",
        "        node (obj): node providing forward(x, *args).\n",
        "        x (ndarray): input passed to forward.\n",
        "        var (ndarray): array to differentiate with respect to. It has to be an\n",
        "            array that forward actually reads (x itself or a parameter of\n",
        "            node); it is perturbed in place and restored before returning.\n",
        "        y_delta (ndarray): gradient passed from the next node, dL/dy.\n",
        "        delta (float): full width of the finite-difference step.\n",
        "        args (tuple): additional positional arguments for forward.\n",
        "\n",
        "    Returns:\n",
        "        ndarray: 1-D array with one gradient entry per element of `var`,\n",
        "            in C (row-major) order.\n",
        "    '''\n",
        "    var_delta_list = []\n",
        "    # `var.flat` writes through to `var` even when it is not contiguous;\n",
        "    # `var.ravel()` would hand back a copy in that case and the perturbation\n",
        "    # would never reach the node.\n",
        "    for ix in range(var.size):\n",
        "        original = var.flat[ix]\n",
        "        var.flat[ix] = original + delta / 2.\n",
        "        yplus = node.forward(x, *args)\n",
        "        var.flat[ix] = original - delta / 2.\n",
        "        yminus = node.forward(x, *args)\n",
        "        # restore the exact original value (no floating point round trip)\n",
        "        var.flat[ix] = original\n",
        "        var_delta = ((yplus - yminus) / delta * y_delta).sum()\n",
        "        var_delta_list.append(var_delta)\n",
        "    return np.array(var_delta_list)\n",
        "\n",
        "\n",
        "def gradient_test(node, x, args=(), delta=0.01, precision=1e-3):\n",
        "    '''\n",
        "    Sanity check for a node: compare the gradients from `node.backward`\n",
        "    with numerical differentiation, for the input and for every parameter.\n",
        "    An AssertionError is raised as soon as one comparison fails.\n",
        "\n",
        "    Args:\n",
        "        node (obj): user defined neural network node.\n",
        "        x (ndarray): input array.\n",
        "        args: additional arguments for the forward function.\n",
        "        delta: the strength of perturbation used in numdiff.\n",
        "        precision: tolerance of the comparison, used both as absolute and\n",
        "            as relative tolerance.\n",
        "    '''\n",
        "    y = node.forward(x, *args)\n",
        "    # a random upstream gradient dL/dy stands in for the rest of the network\n",
        "    y_delta = np.random.randn(*y.shape)\n",
        "    x_delta = node.backward(y_delta)\n",
        "\n",
        "    variables = [x] + node.parameters\n",
        "    analytic_grads = [x_delta] + node.parameters_deltas\n",
        "    for var, var_delta in zip(variables, analytic_grads):\n",
        "        numeric = numdiff(node, x, var, y_delta, delta, args)\n",
        "        numeric = numeric.reshape(*var_delta.shape)\n",
        "        np.testing.assert_allclose(numeric, var_delta, atol=precision, rtol=precision)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "eAAWQzfwBCHT",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 105
        },
        "outputId": "ba28defd-5024-4456-eb90-6e909a694482",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1519078177218,
          "user_tz": -60,
          "elapsed": 622,
          "user": {
            "displayName": "刘金国",
            "photoUrl": "//lh3.googleusercontent.com/-lDAT81T3HSE/AAAAAAAAAAI/AAAAAAAAAgw/eH3JEob7M1Y/s50-c-k-no/photo.jpg",
            "userId": "116824001998056121289"
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# seed the RNG so that the random inputs and upstream gradients are reproducible\n",
        "np.random.seed(5)\n",
        "# nodes whose forward only needs x share the same generic check\n",
        "for node in [Linear(6, 4), Sigmoid(), Softmax(), Mean()]:\n",
        "    print('checking %s' % node.__class__.__name__)\n",
        "    x = np.random.uniform(size=(5, 6))\n",
        "    gradient_test(node, x)\n",
        "\n",
        "# we take special care of cross entropy node here,\n",
        "# it takes an additional parameter p\n",
        "node = CrossEntropy()\n",
        "print('checking %s' % node.__class__.__name__)\n",
        "# p and x are both rescaled so that every row sums to 1\n",
        "p = np.random.uniform(0.1, 1, [5, 6])\n",
        "p = p / p.sum(axis=-1, keepdims=True)\n",
        "x = np.random.uniform(0.1, 1, [5, 6])\n",
        "x = x / x.sum(axis=-1, keepdims=True)\n",
        "# this check runs with a looser tolerance (1e-1) than gradient_test's default 1e-3\n",
        "gradient_test(node, x, args=(p,), precision=1e-1)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "checking Linear\n",
            "checking Sigmoid\n",
            "checking Softmax\n",
            "checking Mean\n",
            "checking CrossEntropy\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "M2gJ6_HtEIfp",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Load MNIST data"
      ]
    },
    {
      "metadata": {
        "id": "frJZmkq1BCHp",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "import subprocess\n",
        "import os\n",
        "import struct\n",
        "\n",
        "\n",
        "def load_MNIST(base='http://yann.lecun.com/exdb/mnist/', path='data/raw/'):\n",
        "    '''\n",
        "    Download (when necessary) and unpack the MNIST data set.\n",
        "\n",
        "    The download step needs the external `wget` and `gzip` commands.\n",
        "\n",
        "    Args:\n",
        "        base (str): URL prefix the four `.gz` archives are fetched from.\n",
        "        path (str): directory (with trailing slash) that caches the raw files.\n",
        "\n",
        "    Returns:\n",
        "        list: length is 4. They are training set data, training set label,\n",
        "            test set data and test set label. Data arrays are uint8 and shaped\n",
        "            as stated in the file header; labels are one-hot float arrays\n",
        "            with 10 columns.\n",
        "    '''\n",
        "    # Training files are listed first so that the returned order matches the\n",
        "    # documented (train data, train label, test data, test label) contract.\n",
        "    objects = ['train-images-idx3-ubyte', 'train-labels-idx1-ubyte',\n",
        "               't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte']\n",
        "    end = '.gz'\n",
        "    if not os.path.isdir(path):\n",
        "        os.makedirs(path)\n",
        "\n",
        "    missing = [obj for obj in objects if not os.path.isfile(path + obj)]\n",
        "    if missing:\n",
        "        print('Downloading MNIST dataset. Please do not stop the program '\n",
        "              'during the download. If you do, remove `data` folder and try again.')\n",
        "    for obj in missing:\n",
        "        subprocess.check_call(['wget', base + obj + end, '-P', path])\n",
        "        subprocess.check_call(['gzip', '-d', path + obj + end])\n",
        "\n",
        "    def unpack(filename):\n",
        "        '''Read a single file into a uint8 array shaped as its header states.'''\n",
        "        with open(filename, 'rb') as f:\n",
        "            # the last byte of the 4-byte header prefix is the number of dimensions\n",
        "            _, _, dims = struct.unpack('>HBB', f.read(4))\n",
        "            # one big-endian unsigned 32-bit size per dimension\n",
        "            shape = tuple(struct.unpack('>I', f.read(4))[0] for _ in range(dims))\n",
        "            # np.frombuffer replaces the deprecated binary mode of np.fromstring;\n",
        "            # the resulting array is read-only\n",
        "            return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)\n",
        "\n",
        "    def one_hot(digits):\n",
        "        '''Turn a 1-D array of digit labels into an (N, 10) one-hot array.'''\n",
        "        encoded = np.zeros([digits.shape[0], 10])\n",
        "        encoded[np.arange(digits.shape[0]), digits] = 1\n",
        "        return encoded\n",
        "\n",
        "    # load objects\n",
        "    data = [unpack(path + name) for name in objects]\n",
        "    data[1] = one_hot(data[1])\n",
        "    data[3] = one_hot(data[3])\n",
        "    return data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "rfBw8SUTBCH0",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def random_draw(data, label, batch_size):\n",
        "    '''\n",
        "    Draw a random batch of samples (without repetition) and their labels.\n",
        "\n",
        "    Args:\n",
        "        data (ndarray): dataset with the first axis the batch dimension.\n",
        "        label (ndarray): one-hot label for dataset,\n",
        "            for example, 3 is [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]\n",
        "        batch_size (int): size of batch, the number of data to draw.\n",
        "\n",
        "    Returns:\n",
        "        tuple: (samples flattened to 2-D and divided by 255.0, matching labels).\n",
        "    '''\n",
        "    chosen = np.random.permutation(data.shape[0])[:batch_size]\n",
        "    batch = data[chosen]\n",
        "    flat = batch.reshape([batch.shape[0], -1])\n",
        "    return flat / 255.0, label[chosen]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "WS4eNe71ERQY",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Run the training"
      ]
    },
    {
      "metadata": {
        "id": "wbtSgb1yBCIA",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            },
            {
              "item_id": 2
            },
            {
              "item_id": 6
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 319
        },
        "outputId": "00bb845d-f7bc-4b10-b565-72b8a34f50a0",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1519078186701,
          "user_tz": -60,
          "elapsed": 5037,
          "user": {
            "displayName": "刘金国",
            "photoUrl": "//lh3.googleusercontent.com/-lDAT81T3HSE/AAAAAAAAAAI/AAAAAAAAAgw/eH3JEob7M1Y/s50-c-k-no/photo.jpg",
            "userId": "116824001998056121289"
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# seed the RNG so that parameter initialization and batch sampling are reproducible\n",
        "np.random.seed(5)\n",
        "batch_size = 100\n",
        "learning_rate = 0.5\n",
        "# input size of the Linear node, i.e. the length of one flattened image\n",
        "dim_img = 784\n",
        "# output size of the Linear node, one entry per digit class\n",
        "num_digit = 10\n",
        "# an epoch means running through the training set roughly once\n",
        "num_epoch = 10\n",
        "# the four arrays are unpacked in the order documented by load_MNIST\n",
        "train_data, train_label, test_data, test_label = load_MNIST()\n",
        "# number of mini-batches drawn per epoch\n",
        "num_iteration = len(train_data) // batch_size\n",
        "\n",
        "\n",
        "def match_ratio(result, label):\n",
        "    '''\n",
        "    Fraction of rows whose largest entry in `result` sits at the same\n",
        "    position as the largest entry in `label`.\n",
        "    '''\n",
        "    predicted = np.argmax(result, axis=1)\n",
        "    expected = np.argmax(label, axis=1)\n",
        "    return np.sum(predicted == expected) / expected.shape[0]\n",
        "\n",
        "\n",
        "# keep a handle on the loss node: net_forward compares against it to decide\n",
        "# where the label is fed in\n",
        "lossfunc = CrossEntropy()\n",
        "# define a list as a network, nodes are chained up\n",
        "net = [Linear(dim_img, num_digit), Softmax(), lossfunc, Mean()]\n",
        "\n",
        "\n",
        "def net_forward(x, label):\n",
        "    '''\n",
        "    Forward pass through the sequential network `net`.\n",
        "\n",
        "    The value that enters `lossfunc` is kept as the prediction, and `label`\n",
        "    is passed to `lossfunc` only.\n",
        "\n",
        "    Returns:\n",
        "        tuple: (input of the loss node, output of the last node).\n",
        "    '''\n",
        "    for node in net:\n",
        "        is_loss = node is lossfunc\n",
        "        if is_loss:\n",
        "            result = x\n",
        "        x = node.forward(x, label) if is_loss else node.forward(x)\n",
        "    return result, x\n",
        "\n",
        "\n",
        "def net_backward():\n",
        "    '''\n",
        "    Backward pass through the sequential network `net`.\n",
        "\n",
        "    Starts with a gradient of 1.0 at the last node and walks the nodes in\n",
        "    reverse order; returns the gradient with respect to the network input.\n",
        "    '''\n",
        "    grad = 1.0\n",
        "    for node in reversed(net):\n",
        "        grad = node.backward(grad)\n",
        "    return grad\n",
        "\n",
        "# display test loss before training\n",
        "x, label = random_draw(test_data, test_label, 1000)\n",
        "result, loss = net_forward(x, label)\n",
        "print('Before Training.\\nTest loss = %.4f, correct rate = %.3f' % (loss, match_ratio(result, label)))\n",
        "\n",
        "for epoch in range(num_epoch):\n",
        "    for j in range(num_iteration):\n",
        "        x, label = random_draw(train_data, train_label, batch_size)\n",
        "        result, loss = net_forward(x, label)\n",
        "\n",
        "        # fills parameters_deltas of every node; the returned input gradient is not needed\n",
        "        net_backward()\n",
        "\n",
        "        # update network parameters\n",
        "        for node in net:\n",
        "            for p, p_delta in zip(node.parameters, node.parameters_deltas):\n",
        "                # `-=` modifies the parameter array in place, so the node sees the new values\n",
        "                p -= learning_rate * p_delta  # stochastic gradient descent\n",
        "\n",
        "    # loss and correct rate printed here come from the last mini-batch of the epoch only\n",
        "    print(\"epoch = %d/%d, loss = %.4f, corret rate = %.2f\" %\n",
        "          (epoch, num_epoch, loss, match_ratio(result, label)))\n",
        "\n",
        "# display test loss after training, on a fresh random draw of 1000 test samples\n",
        "x, label = random_draw(test_data, test_label, 1000)\n",
        "result, loss = net_forward(x, label)\n",
        "print('After Training.\\nTest loss = %.4f, correct rate = %.3f' % (loss, match_ratio(result, label)))"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading MNIST dataset, please do not stop the programming    during downloading, if you do, remove `data` folder and try again.\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:35: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Before Training.\n",
            "Test loss = 2.3216, correct rate = 0.136\n",
            "epoch = 0/10, loss = 0.3060, corret rate = 0.92\n",
            "epoch = 1/10, loss = 0.3170, corret rate = 0.94\n",
            "epoch = 2/10, loss = 0.3278, corret rate = 0.89\n",
            "epoch = 3/10, loss = 0.3740, corret rate = 0.90\n",
            "epoch = 4/10, loss = 0.4193, corret rate = 0.88\n",
            "epoch = 5/10, loss = 0.1788, corret rate = 0.93\n",
            "epoch = 6/10, loss = 0.3220, corret rate = 0.89\n",
            "epoch = 7/10, loss = 0.2003, corret rate = 0.96\n",
            "epoch = 8/10, loss = 0.2158, corret rate = 0.92\n",
            "epoch = 9/10, loss = 0.3298, corret rate = 0.85\n",
            "After Training.\n",
            "Test loss = 0.3418, correct rate = 0.905\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "49AapjrgH-gb",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Task\n",
        "Usually, **Softmax** and **CrossEntropy** are combined into a single **SoftmaxCrossEntropy** node. The combined function avoids potential rounding errors, and its back propagation is also more efficient.\n",
        "\n",
        "Write your own version of **SoftmaxCrossEntropy** below.\n",
        "* Let it pass the numdiff check\n",
        "* Replace it in the training procedure"
      ]
    },
    {
      "metadata": {
        "id": "CHyOFkjD2y-2",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# complete your task here"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "oWV_9MID2zj1",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Solution\n",
        "$$y=\\sum\\limits_i(\\log(\\sum\\limits_j e^{x_j})-x_i)p_i$$\n",
        "$$\\frac{\\partial \\mathcal{L}}{\\partial x_i}=\\frac{\\partial \\mathcal{L}}{\\partial y}(\\frac{e^{x_i}}{\\sum_j e^{x_j}}-p_i)$$\n",
        "\n",
        "The gradient formula assumes $\\sum_i p_i = 1$, which holds for one-hot labels."
      ]
    },
    {
      "metadata": {
        "id": "qMKmYM1JIlRR",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "class SoftmaxCrossEntropy(F):\n",
        "    '''\n",
        "    Softmax followed by cross entropy, fused into a single node.\n",
        "\n",
        "    Attributes:\n",
        "        p (ndarray): target probabilities from the last forward call.\n",
        "        rho (ndarray): exponential of the max-shifted input, i.e. the\n",
        "            unnormalized softmax.\n",
        "        Z (ndarray): sum of rho over the last axis (kept as an axis of length 1).\n",
        "    '''\n",
        "    def forward(self, x, p, *args):\n",
        "        '''Return sum_i (log(Z) - x_i) * p_i over the last axis.'''\n",
        "        # shifting by the per-row maximum avoids data overflow in exp\n",
        "        shifted = x - x.max(axis=-1, keepdims=True)\n",
        "        self.p = p\n",
        "        self.rho = np.exp(shifted)\n",
        "        self.Z = self.rho.sum(axis=-1, keepdims=True)\n",
        "        return ((np.log(self.Z) - shifted) * p).sum(axis=-1)\n",
        "\n",
        "    def backward(self, delta):\n",
        "        '''Return dL/dx = dL/dy * (rho / Z - p).'''\n",
        "        softmax = self.rho / self.Z\n",
        "        # delta has one entry per row; add a trailing axis so it broadcasts over classes\n",
        "        return delta[..., None] * (softmax - self.p)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "48fY0VaS0HCm",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "f29774ab-ea7e-4bc8-d5a9-a95a39fb8159",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1517846681314,
          "user_tz": -480,
          "elapsed": 823,
          "user": {
            "displayName": "刘金国",
            "photoUrl": "//lh3.googleusercontent.com/-lDAT81T3HSE/AAAAAAAAAAI/AAAAAAAAAgw/eH3JEob7M1Y/s50-c-k-no/photo.jpg",
            "userId": "116824001998056121289"
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# numerical gradient check of the fused node\n",
        "node = SoftmaxCrossEntropy()\n",
        "print('checking %s' % node.__class__.__name__)\n",
        "# target probabilities: every row is rescaled to sum to 1\n",
        "p = np.random.uniform(0.1, 1, [5, 6])\n",
        "p = p / p.sum(axis=-1, keepdims=True)\n",
        "# x is left unnormalized here; the node applies the softmax itself\n",
        "x = np.random.uniform(0.1, 1, [5, 6])\n",
        "gradient_test(node, x, args=(p,), precision=1e-3)"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "checking SoftmaxCrossEntropy\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "g4_x31Xn1kpW",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            },
            {
              "item_id": 3
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 283
        },
        "outputId": "cbe5b9b3-0453-4e97-e929-fc4056b4a42c",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1517846683819,
          "user_tz": -480,
          "elapsed": 2244,
          "user": {
            "displayName": "刘金国",
            "photoUrl": "//lh3.googleusercontent.com/-lDAT81T3HSE/AAAAAAAAAAI/AAAAAAAAAgw/eH3JEob7M1Y/s50-c-k-no/photo.jpg",
            "userId": "116824001998056121289"
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# seed the RNG so that parameter initialization and batch sampling are reproducible\n",
        "np.random.seed(5)\n",
        "batch_size = 100\n",
        "learning_rate = 0.5\n",
        "# input size of the Linear node, i.e. the length of one flattened image\n",
        "dim_img = 784\n",
        "# output size of the Linear node, one entry per digit class\n",
        "num_digit = 10\n",
        "# an epoch means running through the training set roughly once\n",
        "num_epoch = 10\n",
        "# the four arrays are unpacked in the order documented by load_MNIST\n",
        "train_data, train_label, test_data, test_label = load_MNIST()\n",
        "# number of mini-batches drawn per epoch\n",
        "num_iteration = len(train_data) // batch_size\n",
        "\n",
        "\n",
        "def match_ratio(result, label):\n",
        "    '''\n",
        "    Match rate between result and label: the share of rows in which both\n",
        "    arrays have their largest entry at the same position.\n",
        "    '''\n",
        "    best_guess = np.argmax(result, axis=1)\n",
        "    truth = np.argmax(label, axis=1)\n",
        "    hits = np.sum(best_guess == truth)\n",
        "    return hits / truth.shape[0]\n",
        "\n",
        "\n",
        "# the fused node replaces the separate Softmax and CrossEntropy nodes;\n",
        "# net_forward compares against `lossfunc` to decide where the label is fed in\n",
        "lossfunc = SoftmaxCrossEntropy()\n",
        "net = [Linear(dim_img, num_digit), lossfunc, Mean()]\n",
        "\n",
        "\n",
        "def net_forward(x, label):\n",
        "    '''\n",
        "    Forward pass through the sequential network `net`.\n",
        "\n",
        "    Returns:\n",
        "        tuple: (`rho` of the loss node, output of the last node). `rho` is the\n",
        "            unnormalized softmax; its row-wise argmax equals that of the\n",
        "            normalized probabilities.\n",
        "    '''\n",
        "    for node in net:\n",
        "        if node is not lossfunc:\n",
        "            x = node.forward(x)\n",
        "            continue\n",
        "        # only the loss node receives the label\n",
        "        x = node.forward(x, label)\n",
        "        result = node.rho\n",
        "    return result, x\n",
        "\n",
        "\n",
        "def net_backward(y_delta):\n",
        "    '''\n",
        "    Backward pass through the sequential network `net`.\n",
        "\n",
        "    Args:\n",
        "        y_delta: gradient fed into the last node.\n",
        "\n",
        "    Returns:\n",
        "        gradient with respect to the network input.\n",
        "    '''\n",
        "    for node in reversed(net):\n",
        "        y_delta = node.backward(y_delta)\n",
        "    return y_delta\n",
        "\n",
        "# display test loss before training\n",
        "x, label = random_draw(test_data, test_label, 1000)\n",
        "result, loss = net_forward(x, label)\n",
        "print('Before Training.\\nTest loss = %.4f, correct rate = %.3f' % (loss, match_ratio(result, label)))\n",
        "\n",
        "for epoch in range(num_epoch):\n",
        "    for j in range(num_iteration):\n",
        "        x, label = random_draw(train_data, train_label, batch_size)\n",
        "        result, loss = net_forward(x, label)\n",
        "\n",
        "        # the learning rate is fed in as the initial gradient, so every stored\n",
        "        # parameter gradient already carries that factor\n",
        "        net_backward(learning_rate)\n",
        "\n",
        "        # update network parameters\n",
        "        for node in net:\n",
        "            for p, p_delta in zip(node.parameters, node.parameters_deltas):\n",
        "                # `-=` modifies the parameter array in place, so the node sees the new values\n",
        "                p -= p_delta  # stochastic gradient descent\n",
        "\n",
        "    # loss and correct rate printed here come from the last mini-batch of the epoch only\n",
        "    print(\"epoch = %d/%d, loss = %.4f, corret rate = %.2f\" %\n",
        "          (epoch, num_epoch, loss, match_ratio(result, label)))\n",
        "\n",
        "# display test loss after training, on a fresh random draw of 1000 test samples\n",
        "x, label = random_draw(test_data, test_label, 1000)\n",
        "result, loss = net_forward(x, label)\n",
        "print('Test loss = %.4f, correct rate = %.3f' % (loss, match_ratio(result, label)))"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:34: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Before Training.\n",
            "Test loss = 2.3216, correct rate = 0.136\n",
            "epoch = 0/10, loss = 0.3060, corret rate = 0.92\n",
            "epoch = 1/10, loss = 0.3170, corret rate = 0.94\n",
            "epoch = 2/10, loss = 0.3278, corret rate = 0.89\n",
            "epoch = 3/10, loss = 0.3740, corret rate = 0.90\n",
            "epoch = 4/10, loss = 0.4193, corret rate = 0.88\n",
            "epoch = 5/10, loss = 0.1788, corret rate = 0.93\n",
            "epoch = 6/10, loss = 0.3220, corret rate = 0.89\n",
            "epoch = 7/10, loss = 0.2003, corret rate = 0.96\n",
            "epoch = 8/10, loss = 0.2158, corret rate = 0.92\n",
            "epoch = 9/10, loss = 0.3298, corret rate = 0.85\n",
            "Test loss = 0.3418, correct rate = 0.905\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "b98ysE1t2HFr",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}